import os
import copy
from continual_rl.policies.policy_base import PolicyBase
from continual_rl.policies.impala.impala_policy_config import ImpalaPolicyConfig
from continual_rl.policies.impala.impala_environment_runner import ImpalaEnvironmentRunner
from continual_rl.policies.impala.nets import ImpalaNet
from continual_rl.policies.impala.torchbeast.monobeast import Monobeast
from continual_rl.utils.utils import Utils


class ImpalaPolicy(PolicyBase):
    """
    With IMPALA, the parallelism is the point, so rather than splitting it up into compute_action and train like normal,
    just let the existing IMPALA implementation handle it all.
    This policy is essentially a container for the Monobeast object itself; it holds the persistent information
    (e.g. the model and replay buffer).
    """

    def __init__(self, config: ImpalaPolicyConfig, observation_spaces, action_spaces, impala_class: Monobeast = None,
                 policy_net_class: ImpalaNet = None):
        """
        :param config: The IMPALA policy configuration. It is not modified; the trainer receives a deep copy.
        :param observation_spaces: Observation spaces, forwarded to the trainer constructor.
        :param action_spaces: Action spaces, forwarded to the trainer constructor.
        :param impala_class: The trainer class to instantiate (a class, not an instance). Defaults to Monobeast.
        :param policy_net_class: The network class handed to the trainer. Defaults to ImpalaNet.
        """
        super().__init__(config)
        self._config = config
        self._observation_spaces = observation_spaces
        self._action_spaces = action_spaces

        self.model_flags = self._create_model_flags()

        # Use `is None` (not truthiness) so only an omitted argument falls back to the default class.
        self.impala_class = Monobeast if impala_class is None else impala_class
        self.policy_net_class = ImpalaNet if policy_net_class is None else policy_net_class

        self.impala_trainer = self._create_impala_trainer()

    def _create_model_flags(self):
        """
        Finishes populating the config to contain the rest of the flags used by IMPALA in the creation of the model.

        :return: A deep copy of the config with `savedir` set to the string form of `config.output_dir`.
        """
        # torchbeast will change flags, so copy it so config remains unchanged for other tasks.
        flags = copy.deepcopy(self._config)
        flags.savedir = str(self._config.output_dir)
        return flags

    def _create_impala_trainer(self):
        """
        Instantiates a new trainer from the stored trainer class, model flags, spaces, and network class.
        """
        return self.impala_class(self.model_flags, self._observation_spaces, self._action_spaces,
                                 self.policy_net_class)

    def get_environment_runner(self, task_spec):
        """
        Returns an ImpalaEnvironmentRunner bound to this policy.

        :param task_spec: Unused here.
        """
        return ImpalaEnvironmentRunner(self._config, self)

    def compute_action(self, observation, task_id, action_space_id, last_timestep_data, eval_mode):
        """
        No-op: acting is handled by the IMPALA trainer itself (see the class docstring).
        """
        pass

    def train(self, storage_buffer):
        """
        No-op: training is handled by the IMPALA trainer itself (see the class docstring).
        """
        pass

    def save(self, output_path_dir, cycle_id, task_id, task_total_steps):
        """
        Delegates saving to the trainer, passing only output_path_dir.
        cycle_id, task_id and task_total_steps are accepted for interface compatibility but unused.
        """
        self.impala_trainer.save(output_path_dir)

    def load(self, output_path_dir):
        """
        Delegates loading from output_path_dir to the trainer.
        """
        self.impala_trainer.load(output_path_dir)

    def task_change(self, task_run_id=0):
        """
        Called on a task change. If the config's task_change_reset flag is set, the current trainer is replaced
        with a newly constructed one, discarding whatever state the old trainer held.

        :param task_run_id: Unused here.
        """
        # Read the flag from self._config, which this class always sets in __init__ and uses everywhere else;
        # self.config is only available if PolicyBase happens to define it.
        if self._config.task_change_reset:
            # Create a new trainer when the task changes.
            # NOTE(review): the previous trainer is simply dropped. If Monobeast owns actor processes or other
            # resources, confirm they are released.
            # NOTE(review): self.model_flags is reused as-is; since torchbeast mutates flags, the new trainer sees
            # any changes the previous one made — confirm that is intended.
            self.impala_trainer = self._create_impala_trainer()
